TensorRT-LLM × PyTorch: A New Development Paradigm for High-Performance LLM Inference

加速计算专家团队 薛博阳

目录

议程 (Agenda)

引言 (Introduction)

快速入门 (Quick Start)

LLM API

用户可以通过几行代码尝试 PyTorch 工作流。

# Import LLM API
from tensorrt_llm.torch import LLM

# Create a LLM object
llm = LLM(model="./Llama-3.1-8B-Instruct")

# Prepare prompts
prompts = [
    "Hi, pls tell me something about reasoning model",
    "Hi, pls tell me something about TensorRT-LLM"
]

# Generate output
output = llm.generate(prompts)

PyTorch workflow

* LLM API 的设计思路借鉴自 vLLM 团队。

更多示例和参数
更多带有附加参数的示例可在 examples/pytorch/quickstart_advanced.py 中找到。

概览 (Overview)

下图展示了 TensorRT-LLM 中 TensorRT 工作流和新的 PyTorch 工作流的整体架构。

TensorRT-LLM 与 PyTorch 工作流概览 (Page 5)
TensorRT-LLM 与 PyTorch 工作流概览 (Page 5)

两个工作流共享相同的上层 Serving 和 API 接口,但在 Runtime 和 Modeling 层面有所不同。

高亮 API 和 Runtime 部分的架构概览 (Page 6)
高亮 API 和 Runtime 部分的架构概览 (Page 6)

LLM API 对比

版本1:传统 TensorRT 工作流

此工作流需要手动进行模型转换和引擎构建,并使用 Python 包装器调用 C++ 运行时来构建新模型。

传统 TensorRT 工作流示意图 (Page 7)
传统 TensorRT 工作流示意图 (Page 7)

版本2:LLM API TensorRT 工作流

这是一个基于 LLM API 的单步工作流,底层使用 TensorRT。它通过 Python 包装器简化了新模型的构建,并使用带有 Python 绑定的 C++ 运行时。

LLM API TensorRT 工作流示意图 (Page 8)
LLM API TensorRT 工作流示意图 (Page 8)

版本3:LLM API PyTorch 工作流

这是基于 LLM API 的单步工作流,底层使用 PyTorch。它采用基于 PyTorch 的模型 API 来构建新模型,并通过重用模块化的 C++ 运行时来执行。

三种工作流对比示意图 (Page 9)
三种工作流对比示意图 (Page 9)

PyTorch 工作流架构详解

PyTorch 工作流专注于易用性和灵活性,其路径如下图高亮部分所示:

高亮 PyTorch 工作流的架构概览 (Page 12)
高亮 PyTorch 工作流的架构概览 (Page 12)

该流程从 LLM (Torch) API 开始,通过 PyExecutor 和模块化的运行时接口,调度 PyTorch Engine 执行。模型层 (torch.nn.Module) 可以使用 PyTorch 原生算子、自定义算子以及复用底层的 TRT-LLM Kernels。

代码结构 (Code Structure)

TensorRT-LLM 的代码结构清晰地划分了不同功能模块。

API

llm.py 模块继承了 LLM API,是 PyTorch 工作流的用户入口。

代码结构 - API 部分 (Page 13)
代码结构 - API 部分 (Page 13)

Runtime

pyexecutor/ 目录包含了 Python 运行时的实现。

代码结构 - Runtime 部分 (Page 14)
代码结构 - Runtime 部分 (Page 14)

Modeling

模型定义相关代码位于多个模块中,包括:
* attention_backend/: 实现了多种 Attention 后端,如 Vanilla, flashinfer, TRT-LLM, StarAttention。
* models/: 使用 PyTorch 模块实现各种模型。
* modules/: 包含构成模型的基本 PyTorch 模块,如 Linear, Norm, Attention, MLP, MoE 等。

代码结构 - Modeling 部分 (Page 15)
代码结构 - Modeling 部分 (Page 15)

完整代码结构概览

基于 PyTorch 的建模 (PyTorch based Modeling)

下图展示了基于 PyTorch 的建模在整个系统架构中的位置。它位于底层,负责模型的定义和执行,并与上层的 Python 运行时、C++ 运行时以及服务层(如 Triton Inference Server)进行交互。

Page 17
Page 17

整个流程分为以下几个层次:

使用 PyTorch 开发模型

添加新模型

模型层次结构

下图展示了模型的层次结构:

Page 19
Page 19

从外到内依次是:
1. PyTorchModelEngine
2. DecoderModelForCausalLM
3. LMHead
4. DecoderModel
* Embedding
* RMSNorm
* DecoderLayer x N
* RMSNorm
* Attention
* MLP

文档链接: https://nvidia.github.io/TensorRT-LLM/torch/adding_new_model.html

张量 (Tensors)

数据在神经网络中的表示

input_tensor: tensorrt_llm.Tensor
# Slicing (results in a new Tensor)
sliced_tensor = slice(input_tensor, starts=[1, 0], sizes=[2, 2])
# Indexing
indices = constant(np.array([0, 2], dtype=np.int32))
gathered_tensor = gather(input_tensor, dim=0, indices=indices)
# Boolean masking
mask = gt(input_tensor, 5)
masked_tensor = masked_select(input_tensor, mask)
# Unary Op
abs_tensor = input_tensor.abs() # Does not support abs()

* 每个操作都必须产生新的 Tensors。 * 这依赖于图优化来高效执行。

input_tensor: torch.Tensor
# Slicing (creates a view, materialized when needed)
sliced_tensor = input_tensor[1:3, 0:2]
# Indexing
indexed_tensor = input_tensor[[0, 2]]
# Boolean masking
masked_tensor = input_tensor[input_tensor > 5]
# Unary Op generating a new Tensor
abs_tensor = input_tensor.abs()
# Unary Op with in-place modification
abs_tensor = input_tensor.abs_()

* 在可能的情况下创建张量的“视图”(views),仅在需要时物化新的张量。 * 命令式编程更加自然。

函数 (Functionals)

TensorRT 和 PyTorch 中的内置操作

def softmax(input: Tensor, dim: Optional[int] = None) -> Tensor:
    axes = dim_to_trt_axes(dim)
    layer = default_trtnet().add_softmax(input.trt_tensor)
    layer.axes = axes
    return_create_tensor(layer.get_output(0), layer)
def softmax(input: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    return F.softmax(input, dim=dim)

使用自定义核函数 (Custom Kernel)

实现 TensorRT 插件 / PyTorch 操作

调用 TensorRT 插件 / PyTorch 操作

模块 (Modules)

模型的构建块

class RmsNorm(Module):
    def __init__(self, ...):
        # ...
        if self.elementwise_affine:
            self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
        else:
            self.register_parameter('weight', None)
        
        self.eps = eps
        self.dtype = dtype

    def forward(self, x, ...):
        weight = None if self.weight is None else self.weight.value
        if self.normalized_shape is None:
            normalized_shape = self.normalized_shape
        return rms_norm(x, normalized_shape, self.num_groups, weight, self.eps)
class RmsNorm(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=type))
        self.variance_epsilon = eps

    def forward(self, 
                hidden_states: torch.Tensor,
                residual: Optional[torch.Tensor] = None
               ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if IS_FLASHINFER_AVAILABLE:
            from ...custom_op import flashinfer_fused_add_rmsnorm
            # ...
            if residual is not None:
                flashinfer_fused_add_rmsnorm(hidden_states, residual,
                                             self.weight, self.variance_epsilon)
            return hidden_states, residual
        # ...

调试体验

output = self.proj(inter)

self.register_network_output('mlp_output', output)

* 使用标志构建 TensorRT 引擎。

trtllm-build ... --enable_debug_output

* 在 \1 中为调试模式开启。

runner_kwangs = dict(debug_mode=True, ) # ...
model_runner = ModelRunner.from_dir(**runner_kwags)

* 在 \1 中捕获输出张量。

if self.debug_mode:
    #...
    print(self.debug_buffer['transformer.layers.0.mlp_output'])
output = self.down_proj(inter)
print(output.shape, output[0])

* 可以在 IDE 中使用断点。

内置模块 (Built-in Modules)

Page 26
Page 26

Linear (线性层)

用于 GEMM 的统一线性模块

Linear - 权重加载 (Weight Loading)

Page 28
Page 28

扩展模块:在线性层中支持 FP4

Page 29
Page 29

_create_weights (由 load_weights 调用) 中声明模块的参数。

class Linear(nn.Module):
    def _create_weights(self):
        # Quantized weights
        self.weight = Parameter(torch.empty(
            [self.out_features, self.in_features // 2],
            dtype=fp4_utils.float4_e2m1x2,
            device=device),
                                requires_grad=False)

        # FP8 per-block scaling factors
        self.weight_scale = # ...
        
        # FP32 per-tensor global scaling factor = 448x6 / amax_input
        self.input_scale = # ...
        self.inv_input_scale = # ...

        # (amax_input*amax_weight) / (448*6*448*6)
        self.alpha = # ...

        self.profiler = torch.classes.trtllm.FP4GemmRunner.get_instance(
            self.dtype)
        self.needs_profiling = True

正确加载权重缩放因子
- 例如,对于FP4模型,model.layers.0.self_attn.qkv_proj也会接收到:
- model.layers.0.self_attn.q_proj.weight_scale

- `model.layers.0.self_attn.q_proj.weight_scale_2`
- `model.layers.0.self_attn.q_proj.input_scale`
- 这同样适用于 k_proj 和 v_proj。

代码示例:加载和处理FP4权重缩放因子
代码片段(Page 31)展示了load_weight_scales_nvfp4函数,它处理权重的加载,并在拼接后对权重缩放因子进行重排。

实现 apply_linear
- 动态激活量化
- 提供静态全局SF和BF16激活。
- 返回FP4量化的激活和FP8块状SFs。

代码示例:实现apply_linear以支持FP4
代码片段(Page 32)展示了apply_linear函数的实现。该函数首先进行性能分析,然后使用torch.ops.triton.fp4_quantize对激活进行量化,并最终调用run_gemm执行GEMM操作。

Attention 后端

class AttentionBackend(Generic[Metadata]):
    def forward(self,
                q: torch.Tensor,
                k: Optional[torch.Tensor],
                v: Optional[torch.Tensor],
                metadata: TMetadata,
                *,
                attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
                **kwargs) -> torch.Tensor:

添加新模型:以Qwen3为例

代码示例:初始化Qwen3模型和Attention模块
代码片段(Page 34)展示了Qwen3MoeQwen3Attention类的__init__方法,演示了如何使用模型配置来初始化模型组件。

代码示例:Qwen3解码器层和模型的`forward`方法
代码片段(Page 35)展示了Qwen3MoeDecoderLayerQwen3MoeModelforward方法,说明了数据在模型层级间的流动过程。

代码示例:Qwen3因果语言模型的`__init__`和`forward`方法
代码片段(Page 36)展示了Qwen3MoeForCausalLM的实现,它将Qwen3MoeModelLogitsProcessor结合起来,用于因果语言建模任务。

模型权重加载

class DecoderModelForCausalLM(nn.Module,
    def load_weights(self,
                     params_map = {
                         'qkv_proj': ['q_proj', 'k_proj', 'v_proj'],
                         'gate_up_proj': ['gate_proj', 'up_proj']
                     }

建模代码极大简化:以Qwen为例


注意:这并非严格的苹果对苹果比较。例如,1200行的convert.py包含了一些TensorRT工作流不再使用的遗留代码。这只是为了展示PyTorch工作流的代码库更加清晰。

模块化 Python 运行时 (Modularized Python Runtime)

下图展示了运行时在整个系统架构中的位置,重点突出了Python运行时及其与C++组件的交互。

PyTorch 建模架构图
上图(Page 41)展示了整个系统的架构。请求从Triton推理服务器或OpenAI服务器进入,可能通过Dynamo,最终到达基于PyTorch的LLM。系统分为服务层、API层、运行时层和建模层。运行时层包含Python实现的GenerationExecutorPyExecutor,并与C++实现的调度器和KV缓存管理器交互。建模层基于torch.nn.Module,并利用PyTorch原生及自定义操作,底层调用TRT-LLM核。

Python 运行时概览

Python 运行时概览图
上图(Page 42)详细展示了Python运行时的组件结构。PyExecutor是顶层组件。模块化的Python运行时接口定义了ModelEngine, RequestScheduler, BaseResourceManagerDecoder等核心抽象。Python运行时实现了这些接口,例如PyTorchModelEngineSimpleScheduler。底层则调用C++实现的组件,如tensorrt_llm::batch_manager中的调度器和KV缓存管理器。

模块化的运行时模块

执行器循环 (Executor Loop)

CPU/GPU 重叠执行示意图
上图(Page 44)展示了CPU和GPU在执行器循环中的工作流。通过重叠CPU任务(准备、处理)和GPU任务(计算、采样),可以有效隐藏CPU开销,提升整体效率。

*重叠调度器的想法归功于SGLang团队: https://lmsys.org/blog/2024-12-04-sglang-v0-4/

CUDA Graph

基于 PyTorch 工作流的端到端示例:DeepSeek R1 性能优化

性能 (Performance)

Qwen3-235B-A22B FP8

性能图:TRT-LLM 8xH20 1k-1k 性能
上图(Page 39)展示了在8个H20 GPU上,输入和输出序列长度均为1k时的性能表现。比较了DP8EP8(数据并行)和TP8EP8(张量并行)两种配置下的"每GPU输出吞吐量"与"每用户输出吞吐量"的关系。

Commit: a4c3359513dae5694a2a01955abffb7702b004ab
*仅用于技术讨论

结论

行动号召 (Call for actions)

社区与资源

Page 49
Page 49

GitHub仓库:
https://github.com/NVIDIA/TensorRT-LLM

加入 NVIDIA 开发者 Discord 社区